package com.sl.sdn.repository.impl;

import cn.hutool.core.bean.BeanUtil;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.convert.Convert;
import cn.hutool.core.util.NumberUtil;
import cn.hutool.core.util.StrUtil;
import com.sl.sdn.entity.dto.OrganDTO;
import com.sl.sdn.entity.dto.TransportLineNodeDTO;
import com.sl.sdn.entity.node.AgencyEntity;
import com.sl.sdn.enums.OrganTypeEnum;
import com.sl.sdn.repository.AgencyRepository;
import com.sl.sdn.repository.TransportLineRepository;
import org.neo4j.driver.internal.value.PathValue;
import org.neo4j.driver.types.Path;
import org.springframework.data.neo4j.core.Neo4jClient;
import org.springframework.data.neo4j.core.schema.Node;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.util.Map;
import java.util.Optional;

@Service
public class TransportLineRepositoryImpl implements TransportLineRepository {

    @Resource
    private Neo4jClient neo4jClient;

    /**
     * 查询两个网点之间最短的路线，查询深度为：10
     *
     * @param start 开始网点
     * @param end   结束网点
     * @return 路线
     */
    @Override
    public TransportLineNodeDTO findShortestPath(AgencyEntity start, AgencyEntity end) {
        //获取网点数据在Neo4j中的类型 如AGENCY
        String type = AgencyEntity.class.getAnnotation(Node.class).value()[0];
        //构造查询语句
        String cypherQuery = StrUtil.format("match path=shortestPath((start:{})-[*..10]->(end:{})) " +
                "where start.bid=$startId and end.bid=$endId return path", type, type);
        //执行查询
        Optional<TransportLineNodeDTO> optional = this.neo4jClient.query(cypherQuery)
                .bind(start.getBid()).to("startId") //设置参数
                .bind(end.getBid()).to("endId") //设置参数
                .fetchAs(TransportLineNodeDTO.class) //响应数据类型
                .mappedBy(((typeSystem, record) -> { //具体转换
                    PathValue pathValue = (PathValue) record.get(0);
                    Path path = pathValue.asPath();
                    TransportLineNodeDTO dto = new TransportLineNodeDTO();
                    //读取节点数据
                    path.nodes().forEach(node -> {
                        Map<String, Object> map = node.asMap();
                        //将map转为OrganDTO 但会少了type和经纬度
                        OrganDTO organDTO = BeanUtil.toBeanIgnoreError(map, OrganDTO.class);
                        //设置第一个标签为类型 （因为只有一个标签） 这里的type是code 而不是具体的字符串
                        organDTO.setType(OrganTypeEnum.valueOf(CollUtil.getFirst(node.labels())).getCode());
                        //查经纬度 从location对象中获取x和y属性
                        organDTO.setLatitude(BeanUtil.getProperty(map.get("location"), "y")); //纬度
                        organDTO.setLongitude(BeanUtil.getProperty(map.get("location"), "x")); //经度
                        dto.getNodeList().add(organDTO);
                    });
                    //提取关系中的cost进行求和 计算总成本
                    path.relationships().forEach(relationship -> {
                        Map<String, Object> objectMap = relationship.asMap();
                        Double cost = Convert.toDouble(objectMap.get("cost"), 0d); //将Map取出来的值转化为double
                        //求和 计算总成本
                        dto.setCost(NumberUtil.add(cost, dto.getCost()));
                    });
                    //取2位小数
                    dto.setCost(NumberUtil.round(dto.getCost(), 2).doubleValue());
                    return dto;
                })).one();
        return optional.orElse(null);
    }

    /**
     * 查询两个网点之间成本最低的路线，查询深度为：10
     *
     * @param start 开始网点
     * @param end   结束网点
     * @return 路线
     */
    public TransportLineNodeDTO findMinCostPath(AgencyEntity start, AgencyEntity end) {
        //获取网点数据在Neo4j中的类型 如AGENCY
        String type = AgencyEntity.class.getAnnotation(Node.class).value()[0];
        //构造查询语句
        String cypherQuery = StrUtil.format("match path=(start:{})-[*..10]->(end:{}) " +
                "where start.bid=$startId and end.bid=$endId " +
                "unwind relationships(path) as r " +
                "with sum(r.cost) as cost, path " +
                "return path " +
                "order by cost asc, length(path) asc limit 1", type, type);
        //执行查询
        Optional<TransportLineNodeDTO> optional = this.neo4jClient.query(cypherQuery)
                .bind(start.getBid()).to("startId")
                .bind(end.getBid()).to("endId")
                .fetchAs(TransportLineNodeDTO.class)
                .mappedBy(((typeSystem, record) -> {
                    PathValue pathValue = (PathValue) record.get(0);
                    Path path = pathValue.asPath();
                    TransportLineNodeDTO dto = new TransportLineNodeDTO();
                    //读取节点数据
                    path.nodes().forEach(node -> {
                        Map<String, Object> map = node.asMap();
                        OrganDTO organDTO = BeanUtil.toBeanIgnoreError(map, OrganDTO.class);
                        //设置type
                        organDTO.setType(OrganTypeEnum.valueOf(CollUtil.getFirst(node.labels())).getCode());
                        //设置经纬度
                        organDTO.setLongitude(BeanUtil.getProperty("location", "x"));
                        organDTO.setLatitude(BeanUtil.getProperty("location", "y"));
                        dto.getNodeList().add(organDTO);
                    });
                    //提取关系中的cost进行求和 计算总成本
                    path.relationships().forEach(relationship -> {
                        Map<String, Object> objectMap = relationship.asMap();
                        Double cost = Convert.toDouble(objectMap.get("cost"), 0d);
                        dto.setCost(NumberUtil.add(dto.getCost(), cost));
                    });
                    //保留2位小数
                    dto.setCost(NumberUtil.round(dto.getCost(), 2).doubleValue());
                    return dto;
                })).one();
        return optional.orElse(null);
    }


}
